{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading data into StellarGraph from NetworkX\n",
    "\n",
    "> This demo explains how to load data from NetworkX into a form that can be used by the StellarGraph library. [See all other demos](../README.md).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/basics/loading-networkx.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/basics/loading-networkx.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[The StellarGraph library](https://github.com/stellargraph/stellargraph) supports loading graph information from NetworkX graphs. [NetworkX](https://networkx.github.io) is a library for working with graphs that provides many convenient I/O functions, graph algorithms and other tools.\n",
    "\n",
    "If your data is naturally a NetworkX graph, this is a great way to load it. If your data does not *need* to be a NetworkX graph, [loading via another route](README.md) is likely to be faster and potentially more convenient.\n",
    "\n",
    "This notebook walks through loading several kinds of graphs.\n",
    "\n",
    "- homogeneous graph without features (a homogeneous graph is one with only one type of node and one type of edge)\n",
    "- homogeneous graph with features\n",
    "- homogeneous graph with edge weights\n",
    "- directed graphs (a graph is directed if edges have a \"start\" and \"end\" nodes, instead of just connecting two nodes)\n",
    "- heterogeneous graphs (more than one node type and/or more than one edge type) with and without node features or edge weights, this includes knowledge graphs\n",
    "\n",
    "> StellarGraph supports loading data from many sources with all sorts of data preprocessing, via [Pandas](https://pandas.pydata.org) DataFrames, [NumPy](https://www.numpy.org) arrays, [Neo4j](https://neo4j.com) and [NetworkX](https://networkx.github.io) graphs. See [all loading demos](README.md) for more details.\n",
    "\n",
    "The [documentation](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.StellarGraph.from_networkx) for the `StellarGraph.from_networkx` static method includes a compressed reminder of everything discussed in this file, as well as explanations of all of the parameters.\n",
    "\n",
    "The `StellarGraph` class is available at the top level of the `stellargraph` library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.2.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.2.1\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph import StellarGraph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading from many graph formats, via NetworkX\n",
    "\n",
    "A NetworkX graph can be created in many ways, including manually via `add_...` methods ([docs](https://networkx.github.io/documentation/stable/tutorial.html#creating-a-graph)), and by reading from files in many formats ([docs](https://networkx.github.io/documentation/stable/reference/readwrite/index.html)). In this tutorial, we'll only use the first one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Homogeneous graph without features\n",
    "\n",
    "To start with, we'll start with a homogeneous graph without any node features. This means the graph consists of only nodes and edges without any information other than a unique identifier.\n",
    "\n",
    "Let's use a graph representing a square with a diagonal:\n",
    "\n",
    "```\n",
    "a -- b\n",
    "| \\  |\n",
    "|  \\ |\n",
    "d -- c\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = nx.Graph()\n",
    "g.add_edge(\"a\", \"b\")\n",
    "g.add_edge(\"b\", \"c\")\n",
    "g.add_edge(\"c\", \"d\")\n",
    "g.add_edge(\"d\", \"a\")\n",
    "# diagonal\n",
    "g.add_edge(\"a\", \"c\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The basic form of constructing a `StellarGraph` from a NetworkX graphs is... just passing in that graph! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "square = StellarGraph.from_networkx(g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `info` method ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.StellarGraph.info)) gives a high-level summary of a `StellarGraph`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: none\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "print(square.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On this square, it tells us that there's 4 nodes of type `default` (a homogeneous graph still has node and edge types, but they default to `default`), with no features, and one type of edge that touches it. It also tells us that there's 5 edges of type `default` that go between nodes of type `default`. This matches what we expect: it's a graph with 4 nodes and 5 edges and one type of each.\n",
    "\n",
    "Similar to constructing via Pandas, the default node type and edge types can be set using the `node_type_default` and `edge_type_default` parameters to `StellarGraph.from_networkx(...)`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  paper: [4]\n",
      "    Features: none\n",
      "    Edge types: paper-cites->paper\n",
      "\n",
      " Edge types:\n",
      "    paper-cites->paper: [5]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "square_named = StellarGraph.from_networkx(\n",
    "    g, node_type_default=\"paper\", edge_type_default=\"cites\"\n",
    ")\n",
    "print(square_named.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Homogeneous graph with features\n",
    "\n",
    "For many real-world problems, we have more than just graph structure: we have information about the nodes and edges. For instance, we might have a graph of academic papers (nodes) and how they cite each other (edges): we might have information about the nodes such as the authors and the publication year, and even the abstract or full paper contents. If we're doing a machine learning task, it can be useful to feed this information into models. The `StellarGraph.from_networkx` class supports this in two ways:\n",
    "\n",
    "1. loading from an attribute, which stores a numeric sequence\n",
    "2. using a Pandas DataFrame (this is the same as the `StellarGraph(...)` constructor from [the \"loading from Pandas\" tutorial](loading-pandas.ipynb))\n",
    "\n",
    "Let's continue using the same graph.\n",
    "\n",
    "### 1. Loading from an attribute\n",
    "\n",
    "If the nodes of our graph comes or can be augmented with, a feature attribute that contains a numeric sequence (such as a list or a NumPy array), `StellarGraph.from_networkx` can load these to create node features.\n",
    "\n",
    "The feature attributes can be assigned in many ways, such as via `some_graph.nodes[node_id][feature_name] = ...` or by iterating over the nodes. We'll do the second one here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'feature': [97, 1]}, {'feature': [99, 1]})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g_feature_attr = g.copy()\n",
    "\n",
    "\n",
    "def compute_features(node_id):\n",
    "    # in general this could compute something based on other features, but for this example,\n",
    "    # we don't have any other features, so we'll just do something basic with the node_id\n",
    "    return [ord(node_id), len(node_id)]\n",
    "\n",
    "\n",
    "for node_id, node_data in g_feature_attr.nodes(data=True):\n",
    "    node_data[\"feature\"] = compute_features(node_id)\n",
    "\n",
    "# let's see what some of them look like:\n",
    "g_feature_attr.nodes[\"a\"], g_feature_attr.nodes[\"c\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `node_features=...` parameter let's us tell `from_networkx` how to find the features. If it's a string, it looks for an attribute by that name in each node of the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "square_feature_attr = StellarGraph.from_networkx(g_feature_attr, node_features=\"feature\")\n",
    "print(square_feature_attr.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice how `info` now says that nodes of type `default` (all of them) have a feature vector of length 2. Success!\n",
    "\n",
    "In a homogeneous graph like this, the features for every node need to have the same length."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Using a Pandas DataFrame\n",
    "\n",
    "Another option is to have a Pandas DataFrame of features. This is often more efficient, if the data comes separately to the graph structure, or if significant preprocessing is required before creating the `StellarGraph`.\n",
    "\n",
    "The structure of the dataframe is the same as the nodes DataFrame used for the main `StellarGraph(...)` constructor ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.StellarGraph)), which is covered in more detail in [the \"loading from Pandas\" tutorial](loading-pandas.ipynb). That tutorial includes examples of loading the DataFrame from a file. In this tutorial, we will just work with DataFrames that have already been loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>a</th>\n",
       "      <td>1</td>\n",
       "      <td>-0.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>b</th>\n",
       "      <td>2</td>\n",
       "      <td>0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>c</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>d</th>\n",
       "      <td>4</td>\n",
       "      <td>-0.5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   x    y\n",
       "a  1 -0.2\n",
       "b  2  0.3\n",
       "c  3  0.0\n",
       "d  4 -0.5"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "features = pd.DataFrame(\n",
    "    {\"x\": [1, 2, 3, 4], \"y\": [-0.2, 0.3, 0.0, -0.5]}, index=[\"a\", \"b\", \"c\", \"d\"]\n",
    ")\n",
    "features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice how the IDs we used for the nodes in the NetworkX graph are the DataFrame's index. The index is how the features are connected to each node, and the nodes in the graph and nodes in the DataFrame need to match exactly.\n",
    "\n",
    "With a DataFrame in the appropriate format, we can pass this to the `node_features=...` parameter too."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "square_feature_dataframe = StellarGraph.from_networkx(\n",
    "    g_feature_attr, node_features=features\n",
    ")\n",
    "print(square_feature_dataframe.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like with the attribute, `info` now says that our nodes have a feature vector of length 2.\n",
    "\n",
    "We can use Pandas to do all sorts of feature preprocessing, like the column filtering and one-hot encoding done in the other tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Homogeneous graph with edge weights\n",
    "\n",
    "Some algorithms can understand edge weights, which can be used as a measure of the strength of the connection, or a measure of distance between nodes. A `StellarGraph` instance can have weighted edges, by specifying a `weight` attribute on the edges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: none\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: range=[1, 5], mean=3, std=1.58114\n"
     ]
    }
   ],
   "source": [
    "g_weighted = nx.Graph()\n",
    "g_weighted.add_edge(\"a\", \"b\", weight=1.0)\n",
    "g_weighted.add_edge(\"b\", \"c\", weight=2.0)\n",
    "g_weighted.add_edge(\"c\", \"d\", weight=3.0)\n",
    "g_weighted.add_edge(\"d\", \"a\", weight=4.0)\n",
    "# diagonal\n",
    "g_weighted.add_edge(\"a\", \"c\", weight=5.0)\n",
    "\n",
    "square_weighted = StellarGraph.from_networkx(g_weighted)\n",
    "print(square_weighted.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice the output of `info` now shows additional statistics about edge weights.\n",
    "\n",
    "The name of the attribute can be customised using the `edge_weight_attr` parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: none\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: range=[1, 5], mean=3, std=1.58114\n"
     ]
    }
   ],
   "source": [
    "g_weighted_other = nx.Graph()\n",
    "g_weighted_other.add_edge(\"a\", \"b\", distance=1.0)\n",
    "g_weighted_other.add_edge(\"b\", \"c\", distance=2.0)\n",
    "g_weighted_other.add_edge(\"c\", \"d\", distance=3.0)\n",
    "g_weighted_other.add_edge(\"d\", \"a\", distance=4.0)\n",
    "# diagonal\n",
    "g_weighted_other.add_edge(\"a\", \"c\", distance=5.0)\n",
    "\n",
    "square_weighted_other = StellarGraph.from_networkx(\n",
    "    g_weighted, edge_weight_attr=\"distance\"\n",
    ")\n",
    "print(square_weighted.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Directed graphs\n",
    "\n",
    "Some graphs have edge directions, where going from source to target has a different meaning to going from target to source. NetworkX models this using the `DiGraph` and `MultiDiGraph` classes, and `StellarGraph.from_networkx` automatically creates a directed graph if they are passed.\n",
    "\n",
    "All of the other options like node features and edge weights work the same as undirected graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarDiGraph: Directed multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: none\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "g_directed = nx.DiGraph()\n",
    "g_directed.add_edge(\"a\", \"b\")\n",
    "g_directed.add_edge(\"b\", \"c\")\n",
    "g_directed.add_edge(\"c\", \"d\")\n",
    "g_directed.add_edge(\"d\", \"a\")\n",
    "# diagonal\n",
    "g_directed.add_edge(\"a\", \"c\")\n",
    "\n",
    "square_directed = StellarGraph.from_networkx(g_directed)\n",
    "print(square_directed.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heterogeneous graphs\n",
    "\n",
    "Some graphs have multiple types of nodes and multiple types of edges.\n",
    "\n",
    "`StellarGraph` supports all forms of heterogeneous graphs, including knowledge graphs.\n",
    "\n",
    "The types of nodes and edges in a heterogeneous graph created using `StellarGraph.from_networkx` are read from a `label` attribute by default. \n",
    "\n",
    "### Multiple node types\n",
    "\n",
    "Suppose `a` is of type `foo`, and `b`, `c` and `d` are of type `bar`. We can set the `label` attribute on each node appropriate using the `nx.set_node_attributes` function ([docs](https://networkx.github.io/documentation/stable/reference/generated/networkx.classes.function.set_node_attributes.html))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  bar: [3]\n",
      "    Features: none\n",
      "    Edge types: bar-default->bar, bar-default->foo\n",
      "  foo: [1]\n",
      "    Features: none\n",
      "    Edge types: foo-default->bar\n",
      "\n",
      " Edge types:\n",
      "    foo-default->bar: [3]\n",
      "        Weights: all 1 (default)\n",
      "    bar-default->bar: [2]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "g_foo_bar = g.copy()\n",
    "nx.set_node_attributes(\n",
    "    g_foo_bar, {\"a\": \"foo\", \"b\": \"bar\", \"c\": \"bar\", \"d\": \"bar\"}, \"label\"\n",
    ")\n",
    "\n",
    "square_foo_bar = StellarGraph.from_networkx(g_foo_bar)\n",
    "print(square_foo_bar.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the `label` attribute doesn't exist, the `node_type_default` value is used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  bar: [3]\n",
      "    Features: none\n",
      "    Edge types: bar-default->bar, bar-default->foo\n",
      "  foo: [1]\n",
      "    Features: none\n",
      "    Edge types: foo-default->bar\n",
      "\n",
      " Edge types:\n",
      "    foo-default->bar: [3]\n",
      "        Weights: all 1 (default)\n",
      "    bar-default->bar: [2]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "g_foo = g.copy()\n",
    "# only 'a' has a label attribute\n",
    "g_foo.nodes[\"a\"][\"label\"] = \"foo\"\n",
    "\n",
    "square_foo_bar_default = StellarGraph.from_networkx(g_foo, node_type_default=\"bar\")\n",
    "print(square_foo_bar_default.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The attribute used to compute the node or edge type can be customised via the `node_type_attr` parameter. For instance, we can use the `type` attribute instead of the `label` one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  bar: [4]\n",
      "    Features: none\n",
      "    Edge types: bar-default->bar\n",
      "\n",
      " Edge types:\n",
      "    bar-default->bar: [5]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "g_foo_other = g.copy()\n",
    "# only 'a' has a type attribute\n",
    "g_foo_other.nodes[\"a\"][\"type\"] = \"foo\"\n",
    "\n",
    "square_foo_bar_other = StellarGraph.from_networkx(\n",
    "    g_foo, node_type_default=\"bar\", node_type_attr=\"type\"\n",
    ")\n",
    "print(square_foo_bar_other.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we have features for the nodes, the features can be stored in the nodes under a features attribute. The name of the attribute has to be the same for all types, but the shape or size of the attribute does not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "g_foo_bar_attr = g_foo_bar.copy()\n",
    "nx.set_node_attributes(\n",
    "    g_foo_bar_attr,\n",
    "    {\"a\": [], \"b\": [0.4, 100], \"c\": [0.1, 200], \"d\": [0.9, 300]},\n",
    "    \"feature\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  bar: [3]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: bar-default->bar, bar-default->foo\n",
      "  foo: [1]\n",
      "    Features: none\n",
      "    Edge types: foo-default->bar\n",
      "\n",
      " Edge types:\n",
      "    foo-default->bar: [3]\n",
      "        Weights: all 1 (default)\n",
      "    bar-default->bar: [2]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "square_foo_bar_features_attr = StellarGraph.from_networkx(\n",
    "    g_foo_bar_attr, node_features=\"feature\"\n",
    ")\n",
    "print(square_foo_bar_features_attr.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice how the nodes of type `foo` (`a`) have no features, but the nodes of type `bar` (all others) have a vector of length 2.\n",
    "\n",
    "Alternatively, we can use Pandas DataFrames, specifying the `node_features=...` parameter as a dictionary mapping a node type to a DataFrame of features for nodes of that type. The dictionary only needs to include node types that have features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>y</th>\n",
       "      <th>z</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>b</th>\n",
       "      <td>0.4</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>c</th>\n",
       "      <td>0.1</td>\n",
       "      <td>200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>d</th>\n",
       "      <td>0.9</td>\n",
       "      <td>300</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     y    z\n",
       "b  0.4  100\n",
       "c  0.1  200\n",
       "d  0.9  300"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "features_bar = pd.DataFrame(\n",
    "    {\"y\": [0.4, 0.1, 0.9], \"z\": [100, 200, 300]}, index=[\"b\", \"c\", \"d\"]\n",
    ")\n",
    "features_bar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  bar: [3]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: bar-default->bar, bar-default->foo\n",
      "  foo: [1]\n",
      "    Features: none\n",
      "    Edge types: foo-default->bar\n",
      "\n",
      " Edge types:\n",
      "    foo-default->bar: [3]\n",
      "        Weights: all 1 (default)\n",
      "    bar-default->bar: [2]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "square_foo_bar_features_dataframe = StellarGraph.from_networkx(\n",
    "    g_foo_bar, node_features={\"bar\": features_bar}\n",
    ")\n",
    "print(square_foo_bar_features_dataframe.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple edge types\n",
    "\n",
    "A heterogeneous graph with multiple edge types is supported in the same way, by looking for a `label` attribute (the name can be customised with the `edge_type_attr=...` parameter, like `node_type_attr=...`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "g_orientation = nx.Graph()\n",
    "g_orientation.add_edge(\"a\", \"b\", label=\"horizontal\")\n",
    "g_orientation.add_edge(\"b\", \"c\", label=\"vertical\")\n",
    "g_orientation.add_edge(\"c\", \"d\", label=\"horizontal\")\n",
    "g_orientation.add_edge(\"d\", \"a\", label=\"vertical\")\n",
    "g_orientation.add_edge(\"a\", \"c\", label=\"diagonal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: none\n",
      "    Edge types: default-diagonal->default, default-horizontal->default, default-vertical->default\n",
      "\n",
      " Edge types:\n",
      "    default-vertical->default: [2]\n",
      "        Weights: all 1 (default)\n",
      "    default-horizontal->default: [2]\n",
      "        Weights: all 1 (default)\n",
      "    default-diagonal->default: [1]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "square_orientation = StellarGraph.from_networkx(g_orientation)\n",
    "print(square_orientation.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple everything\n",
    "\n",
    "A graph can have multiple node types and multiple edge types, with features or without and with edge weights or without. We can put everything together from the previous sections to make a single complicated `StellarGraph`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "g_everything = g_orientation.copy()\n",
    "nx.set_node_attributes(\n",
    "    g_everything, {\"a\": \"foo\", \"b\": \"bar\", \"c\": \"bar\", \"d\": \"bar\"}, \"label\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  bar: [3]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: bar-diagonal->foo, bar-horizontal->bar, bar-horizontal->foo, bar-vertical->bar, bar-vertical->foo\n",
      "  foo: [1]\n",
      "    Features: none\n",
      "    Edge types: foo-diagonal->bar, foo-horizontal->bar, foo-vertical->bar\n",
      "\n",
      " Edge types:\n",
      "    foo-vertical->bar: [1]\n",
      "        Weights: all 1 (default)\n",
      "    foo-horizontal->bar: [1]\n",
      "        Weights: all 1 (default)\n",
      "    foo-diagonal->bar: [1]\n",
      "        Weights: all 1 (default)\n",
      "    bar-vertical->bar: [1]\n",
      "        Weights: all 1 (default)\n",
      "    bar-horizontal->bar: [1]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "stellar_everything = StellarGraph.from_networkx(\n",
    "    g_everything, node_features={\"bar\": features_bar}\n",
    ")\n",
    "print(stellar_everything.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "You hopefully now know more about building a `StellarGraph` in various configurations via NetworkX.\n",
    "\n",
    "Revisit this document to use as a reminder, or [documentation](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.StellarGraph.from_networkx) for the `StellarGraph.from_networkx` static method.\n",
    "\n",
    "Once you've loaded your data, you can start doing machine learning: a good place to start is the [demo of the GCN algorithm on the Cora dataset for node classification](../node-classification/gcn-node-classification.ipynb). Additionally, StellarGraph includes [many other demos of other algorithms, solving other tasks](../README.md)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/basics/loading-networkx.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/basics/loading-networkx.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
